Aditya Somasundaram

as7458 [at] columbia [dot] edu

How to train LLMs with small batch sizes?

Posted on November 1st 2025 • by Aditya Somasundaram

Everyone “knows” that to train large language models you need large batch sizes. Bigger batches mean less gradient noise, better training stability, lower final loss; so the bigger the batch, the better the model. Or so the folklore goes. In what follows, we'll see that this apparent (and false) necessity for large batches is, in fact, a consequence of how much history the optimizer retains. In reality, we find that small batches converge faster*, are easier to tune, and use less memory — all while maintaining performance.

* on a per-FLOP basis

This post distills findings from our NeurIPS paper, where we study the small batch regime in LLM pretraining.

Check out the paper  →

Most modern optimizers keep exponential moving averages (EMAs) as they step through parameter space. Intuitively, the update at the current point is blended with what the optimizer has seen before. A scalar EMA for a stream \(x_t\) looks like

$$ y_t = \alpha\, y_{t-1} + (1 - \alpha)\, x_t $$

where \(0 \le \alpha \le 1\) sets how quickly the past is forgotten. Smaller \(\alpha\) means you trust the latest observation more, and larger \(\alpha\) means a longer remembrance of history. This "remembrance" can also be expressed as a half-life: the number of steps after which a particular observation's contribution to the average has halved.
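To make the forgetting factor concrete, here is a minimal scalar EMA in Python (an illustrative sketch; the stream and values are made up):

def ema_update(y_prev, x_t, alpha):
    """One EMA step: y_t = alpha * y_{t-1} + (1 - alpha) * x_t."""
    return alpha * y_prev + (1 - alpha) * x_t

# A stream that jumps from 0 to 1: a small alpha tracks the jump within a few
# steps, while a large alpha clings to the old value for much longer.
stream = [0.0] * 5 + [1.0] * 5
for alpha in (0.5, 0.95):
    y = 0.0
    trace = []
    for x in stream:
        y = ema_update(y, x, alpha)
        trace.append(round(y, 3))
    print(f"alpha={alpha}: {trace}")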

Adam, commonly used to train language models, uses these EMAs too. In its simplest form,

$$ \begin{aligned} m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t \\[8pt] v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{\,2} \\[8pt] \theta_{t+1} &= \theta_t - \eta \,\frac{m_t}{\sqrt{v_t} + \varepsilon}\ \end{aligned} $$

where \(g_t\) is the gradient at time step \(t\), \(\beta_1\) controls the momentum (first moment), and \(\beta_2\) controls the preconditioning (second moment).
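For concreteness, the update above in a few lines of Python (a simplified sketch with no bias correction, mirroring the equations; the default hyperparameter values are only illustrative):

import numpy as np

def adam_step(theta, m, v, g, lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8):
    """One simplified Adam step (no bias correction)."""
    m = beta1 * m + (1 - beta1) * g              # first moment: momentum
    v = beta2 * v + (1 - beta2) * g ** 2         # second moment: preconditioner
    theta = theta - lr * m / (np.sqrt(v) + eps)  # preconditioned update
    return theta, m, v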

What do the \(\beta\)s do?

Think of each \(\beta\) as a forgetting factor on the previous estimate. For a simple EMA \(y_t = \alpha y_{t-1} + (1 - \alpha) x_t\), the weight on past data decays geometrically; the half-life \(H\) (in steps) satisfies \(\alpha^H = 1/2\). In Adam's second moment equation \(v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2\), the analogous relation is \(\beta_2^H = 1/2\). So \(\beta_2\) directly sets "how many steps" of history your optimizer remembers. A shorter half-life means updates that react more strongly to incoming data; a longer half-life means smoother updates that prioritize information accumulated in the past.
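In code (a small helper of my own, for illustration), the half-life implied by a given \(\beta\) is:

import math

def half_life_steps(beta):
    """Number of steps H such that beta ** H == 0.5."""
    return math.log(0.5) / math.log(beta)

print(half_life_steps(0.9))    # ~6.6 steps
print(half_life_steps(0.95))   # ~13.5 steps
print(half_life_steps(0.999))  # ~693 steps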


Figure: Effect of \(\beta\) on exponential decay / half-life in Adam.

So far, we've talked about half-life in terms of steps. But in large-scale language model training, a step isn't a fixed unit. Each step processes \(B \times T\) tokens, where \(B\) is the batch size and \(T\) is the sequence length. That means a half-life measured in “steps” is actually a half-life measured in tokens consumed. If we keep \(\beta_2\) fixed while changing the batch size \(B\), then we covertly change how long the optimizer “remembers” in terms of tokens. Smaller batch sizes mean fewer tokens per step, so the same \(\beta_2\) corresponds to a shorter effective half-life in token units. In other words, holding \(\beta_2\) fixed (as is convention) causes the optimizer to “forget faster”.
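A quick back-of-the-envelope calculation makes this visible (an illustrative sketch; I assume a sequence length of \(T = 2048\), and the exact numbers depend on your setup):

import math

def half_life_tokens(beta2, batch_size, seq_len):
    """Second-moment half-life converted from steps to tokens consumed."""
    steps = math.log(0.5) / math.log(beta2)
    return steps * batch_size * seq_len

T = 2048  # assumed sequence length, for illustration
for B in (512, 8, 1):
    print(f"B={B:4d}: {half_life_tokens(0.95, B, T):>14,.0f} tokens")

# B=512 -> ~14M tokens, B=1 -> ~28k tokens: with beta2 held fixed, the optimizer
# at batch size 1 "forgets" roughly 500x faster in token terms.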

How to set \(\beta_2\): half-life, batch size, and sequence length

Ideally, we would like the half-life \(t_2\), measured in tokens, to stay fixed as the batch size \(B\) and sequence length \(T\) change, and therefore to be independent of them. We set \(\beta_2\) as follows.

$$ \boxed{ \displaystyle \beta_2 = 2^{-\frac{B \cdot T}{t_{2}}} } $$

Since each step consumes \(B \cdot T\) tokens, a half-life of \(t_2\) tokens corresponds to \(t_2 / (B \cdot T)\) steps, and requiring \(\beta_2^{\,t_2 / (B \cdot T)} = 1/2\) gives exactly the formula above. This way, regardless of changes in batch size or sequence length, the half-life of the second moment with respect to total tokens remains constant. Furthermore, from this arises a very nice scaling rule for \(\beta_2\) when one changes batch size:

$$ \boxed{ \displaystyle \beta_2^* = \beta_2 ^ {(B^* / B)} } $$

Example: Say a big lab puts out a pretraining recipe with \(B=512\), \(\beta_2 = 0.95\). When I can only fit \(B=1\) on my device, how should I adjust \(\beta_2\)?

$$ \beta_2^* = 0.95^{\frac{1}{512}} \approx 0.9999 $$
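In code, both the token-half-life parameterization and the scaling rule are one-liners (a minimal sketch; the function names are mine, and the call reproduces the example above):

def beta2_from_half_life(t2_tokens, batch_size, seq_len):
    """Choose beta2 so that its second-moment half-life is t2_tokens, independent of B and T."""
    return 2 ** (-(batch_size * seq_len) / t2_tokens)

def rescale_beta2(beta2, B_old, B_new):
    """Scaling rule: keep the token half-life fixed when only the batch size changes."""
    return beta2 ** (B_new / B_old)

print(rescale_beta2(0.95, B_old=512, B_new=1))  # ~0.99990, as in the example above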

In our results below, we show that reframing \(\beta_2\) in terms of its half-life, and holding that half-life fixed, allows for stable training, even down to batch size \(1\)!


Figure: Fixing half-life instead of \(\beta_2\) helps!

The Surprising Simplicity of Small Batch Optimization

To ground the discussion, we sweep optimizer hyperparameters across batch sizes and inspect how the landscape changes as we approach the small batch regime. The emerging pattern is consistent: small batches simplify optimization, while several familiar scaling rules bend or break.

  1. Scanning the Landscape: \(\beta_1\), \(\beta_2\), and Learning Rate

    \(\beta_1 \approx 0.9\) is robust, and fixing the second-moment half-life in tokens, \(t_2\), is crucial. Across batch sizes, the optimal learning rate increases far more slowly than the conventional \(\sqrt{B}\) scaling rule suggests. The standard choice \(\beta_1 = 0.9\) performs reliably across settings. Crucially, holding \(\beta_2\) fixed implicitly changes the effective half-life as batch size varies, which harms performance at small \(B\). Instead, fixing the token half-life \(t_2\) produces consistent behavior across batch sizes.


    Figure: We sweep over learning rate, \(\beta_1\), and \(\beta_2\) across all batch sizes to find the optimal configuration. This was done on FineWeb-Edu for a 30M parameter model.

  2. Hyperparameter Sensitivity across Batch Sizes

    Small batches are more robust to hyperparameters. The sensitivity curves broaden dramatically as \(B\) tends to \(1\). At \(B=1\) you get a wide plateau of near-optimal settings over the learning rate, \(\beta_1\), and the second-moment half-life; at large \(B\) the loss rises quickly once you leave the tuned point. In practice this means fewer sweeps and easier hyperparameter tuning when you're running small batches.


    Figure: Starting from the optimal configuration, we test robustness by changing each hyperparameter independently. We observe that robustness to hyperparameters improves as batch size decreases. The results are again on the 30M model trained on 600M tokens of FineWeb-Edu.

  3. Simple Optimizers Prove Sufficient

    Simple optimizers work when batches are small. With small \(B\), plain SGD (no momentum), Adafactor, Adam, and Muon land at similar loss; as \(B\) grows, the gap between simple and sophisticated optimizers widens. The small batch regime lets you drop optimizer state without giving up performance, which is great for memory-constrained runs (see the rough memory sketch after this list).


    Figure: As batch size decreases, we see that all optimizers tend to perform equivalently. Small batch sizes are more robust to optimizer design.
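To put numbers on the memory argument, here is a rough estimate of optimizer-state size (my own back-of-the-envelope sketch, assuming float32 state; real frameworks differ in dtypes and bookkeeping):

def optimizer_state_gb(n_params, floats_per_param, bytes_per_float=4):
    """Rough optimizer-state footprint in GB."""
    return n_params * floats_per_param * bytes_per_float / 1e9

n = 7e9  # a hypothetical 7B-parameter model
print("Adam (m and v):      ", optimizer_state_gb(n, 2), "GB")  # ~56 GB of extra state
print("SGD without momentum:", optimizer_state_gb(n, 0), "GB")  # 0 GB
# Adafactor sits in between: it factors the second moment, so its state is far
# smaller than Adam's, though not exactly zero.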

Practical recommendations for the broke practitioner (a.k.a. me)

Condensing the findings above into a short checklist:

  1. If GPU memory is tight, shrink the batch size rather than reaching for gradient accumulation; small batches train stably and are easier to tune.

  2. When you change the batch size, rescale \(\beta_2\) so the second-moment half-life in tokens stays fixed: \(\beta_2^* = \beta_2^{\,B^*/B}\).

  3. Keep \(\beta_1 = 0.9\); it is robust across batch sizes.

  4. Don't scale the learning rate by \(\sqrt{B}\); the optimal learning rate grows far more slowly with batch size.

  5. At small batch sizes, simple optimizers (SGD without momentum, Adafactor) match Adam, so you can drop optimizer state to save memory.

Open Questions

While we do find that scaling \(\beta_2\) to keep the half-life fixed helps stabilize training, it is admittedly surprising (at least to me) that a fixed \(\beta_1 = 0.9\) works well across batch sizes.

As Jeremy Bernstein and Laker Newhouse comment:

“EMA can then be thought of as "smoothing out" the algorithm, or making it more robust to mini-batch noise, although nailing down the precise role of EMA is perhaps still an open problem.”

There is something fundamentally different about preconditioning (the second moment) compared to momentum (the first moment), and that difference gives rise to the findings above. A deeper theoretical understanding of EMA, and of what each moment contributes, would help us understand Adam, and by extension optimization algorithms, better.

To Recap

Training LLMs is expensive. Model trainers usually push their hardware to the limit, making sure no resource goes underutilized. For most practitioners, GPU memory is the biggest bottleneck. Given a fixed amount of GPU memory, one must choose wisely what to spend it on. Do I train a bigger model? Do I train with a large batch size? Do I use a more complex optimizer?

In our work, we show that the prevailing assumption that large batch sizes are a necessity for training LLMs is false. We show stable training with small batch sizes by holding the half-life of \(\beta_2\) in Adam fixed. We further find that small batch sizes are more robust to hyperparameter misspecification, and that as batch sizes get smaller, complex optimizers with large optimizer states become unnecessary.

This work was done in collaboration with truly fantastic colleagues. If it helps or inspires your own training efforts, please consider citing us!

@article{marek2025small,
  title={Small batch size training for language models: When vanilla SGD works, and why gradient accumulation is wasteful},
  author={Marek, Martin and Lotfi, Sanae and Somasundaram, Aditya and Wilson, Andrew Gordon and Goldblum, Micah},
  journal={arXiv preprint arXiv:2507.07101},
  year={2025}
}